Deep image prior特征反转部分feature_inversion的实现

74次阅读
没有评论

共计 7904 个字符,预计需要花费 20 分钟才能阅读完成。

提醒:本文最后更新于 2024-08-30 15:39,文中所关联的信息可能已发生改变,请知悉!

导入库

from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
from models import *

import torch
import torch.optim

from utils.feature_inversion_utils import *
from utils.perceptual_loss.perceptual_loss import get_pretrained_net
from utils.common_utils import *

# Enable cuDNN and let it benchmark convolution algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
# Tensor type handed to every `.type(dtype)` call in the cells below.
# It is a CUDA type, so those calls need a GPU.
dtype = torch.cuda.FloatTensor

# When truthy, closure() plots the current output every 200 iterations.
PLOT = True
# Path of the image that is passed to get_image() further down.
fname = './data/feature_inversion/building.jpg'

pretrained_net = 'alexnet_caffe' # alternative seen in the original notebook: 'vgg19_caffe'
layers_to_use = 'fc6' # comma-separated string of layer names e.g. 'fc6,fc7'

载入预训练网络

# Load the pickled network from a local file instead of calling
# get_pretrained_net(pretrained_net), then cast it to `dtype`.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
cnn = torch.load('./data/feature_inversion/alexnet-torch_py3.pth')
cnn = cnn.type(dtype)

opt_content = dict(layers=layers_to_use, what='features')

# Drop every sub-module that comes after the last requested layer
# (by position in cnn._modules).
keys = list(cnn._modules)
max_idx = max(map(keys.index, opt_content['layers'].split(',')))
for k in keys[max_idx + 1:]:
    del cnn._modules[k]

print(cnn)

其中,源码第一行

cnn = get_pretrained_net(pretrained_net).type(dtype)

中的 get_pretrained_net 函数如下:

def _download_if_missing(path, url, label):
    """Fetch ``url`` into ``path`` with wget unless ``path`` already exists.

    Args:
        path: Local file name the weights are stored under.
        url: Location the weights are downloaded from.
        label: Human-readable network name used in messages.

    Raises:
        RuntimeError: If wget cannot be started or exits with a non-zero
            status. A file left behind by the failed download is removed
            first, so a later call retries instead of finding a truncated
            file.
    """
    if os.path.exists(path):
        return

    import subprocess  # function-scope import keeps this block self-contained

    print('Downloading ' + label)
    # Same wget flags as before, but passed as an argument list (no shell).
    # NOTE(security): --no-check-certificate disables TLS verification and
    # the downloaded file is later unpickled by torch.load, so only use this
    # on a network you trust.
    command = ['wget', '-O', path, '--no-check-certificate', '-nc', url]
    try:
        status = subprocess.call(command)
    except OSError as err:
        # wget itself could not be started (e.g. it is not installed).
        raise RuntimeError(
            "Could not run wget to download %s (%s). Download %s manually "
            "and save it as '%s'." % (label, err, url, path))

    if status != 0:
        # A failed `wget -O` may leave an empty or partial file behind. The
        # file did not exist before this call, so removing it is safe.
        if os.path.exists(path):
            os.remove(path)
        raise RuntimeError(
            "wget exited with status %d while downloading %s. Download %s "
            "manually and save it as '%s'." % (status, label, url, path))


def get_pretrained_net(name):
    """Load a pretrained network by name.

    Args:
        name: One of 'alexnet_caffe', 'vgg19_caffe', 'vgg16_caffe' or
            'vgg19_pytorch_modified'.

    Returns:
        The loaded network object.

    Raises:
        ValueError: If ``name`` is not one of the supported names
            (previously a bare ``assert False``, which is skipped when
            Python runs with ``-O``).
        RuntimeError: If required weights are missing and the download
            fails.
    """
    if name == 'alexnet_caffe':
        _download_if_missing(
            'alexnet-torch_py3.pth',
            'https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download',
            'AlexNet')
        # The file holds a whole pickled model, not just a state dict.
        return torch.load('alexnet-torch_py3.pth')

    if name == 'vgg19_caffe':
        _download_if_missing(
            'vgg19-caffe-py3.pth',
            'https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download',
            'VGG-19')
        return get_vgg19_caffe()

    if name == 'vgg16_caffe':
        _download_if_missing(
            'vgg16-caffe-py3.pth',
            'https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download',
            'VGG-16')
        return get_vgg16_caffe()

    if name == 'vgg19_pytorch_modified':
        # No download is attempted here: 'vgg_pytorch_modified.pkl' must
        # already be present in the working directory.
        model = VGGModified(vgg19(pretrained=False), 0.2)
        model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict'])
        return model

    raise ValueError(
        "Unknown pretrained network %r; expected 'alexnet_caffe', "
        "'vgg19_caffe', 'vgg16_caffe' or 'vgg19_pytorch_modified'" % (name,))

其中,例如第 6 行及其类似语句

os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download')

无法正确执行,所以我弃用这个函数,改为本地导入,如下:

cnn = torch.load('./data/feature_inversion/alexnet-torch_py3.pth').type(dtype)

最后能正确执行

E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.normalization.CrossMapLRN2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
Sequential((conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
  (relu1): ReLU()
  (norm1): CrossMapLRN2d(5, alpha=0.0001, beta=0.75, k=1)
  (pool1): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
  (conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2)
  (relu2): ReLU()
  (norm2): CrossMapLRN2d(5, alpha=0.0001, beta=0.75, k=1)
  (pool2): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
  (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3): ReLU()
  (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (relu4): ReLU()
  (conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (relu5): ReLU()
  (pool5): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
  (torch_view): View()
  (fc6): Linear(in_features=9216, out_features=4096, bias=True)
)

载入图片

# Target imsize
# The original test `pretrained_net == 'alexnet'` could never be true for the
# names this notebook uses ('alexnet_caffe', 'vgg19_caffe', ...), so the 227
# branch was unreachable and AlexNet was fed 224x224 images. Matching on the
# prefix restores the intended 227x227 input for AlexNet variants and keeps
# 224 for everything else.
imsize = 227 if pretrained_net.startswith('alexnet') else 224

# Something divisible by a power of two
imsize_net = 256

# VGG and Alexnet need input to be correctly normalized
preprocess, deprocess = get_preprocessor(imsize), get_deprocessor()

# Load the image at the target size, then build the batched network input.
img_content_pil, img_content_np  = get_image(fname, imsize)
img_content_prerocessed = preprocess(img_content_pil)[None,:].type(dtype)

# Bare expression: the notebook displays the loaded image as the cell output.
img_content_pil

Deep image prior 特征反转部分 feature_inversion 的实现

设置匹配器和网络

# Build a matcher for the truncated CNN and the layers listed in opt_content.
matcher_content = get_matcher(cnn, opt_content)

# Run the reference image through the CNN once while the matcher is in
# 'store' mode (presumably this records the target activations; confirm
# against get_matcher).
matcher_content.mode = 'store'
cnn(img_content_prerocessed);
INPUT = 'noise'
pad = 'zero' # 'reflection'
OPT_OVER = 'net' #'net,input'
OPTIMIZER = 'adam' # 'LBFGS'
LR = 0.001

num_iter = 3100

input_depth = 32
# Network input created by get_noise at the imsize_net resolution and
# detached from autograd.
net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()
# Generator network built by skip(); the second positional argument (3) is
# presumably the number of output channels; confirm against models.skip.
net = skip(input_depth, 3, num_channels_down = [16, 32, 64, 128, 128, 128],
                           num_channels_up =   [16, 32, 64, 128, 128, 128],
                           num_channels_skip = [4, 4, 4, 4, 4, 4],   
                           filter_size_down = [7, 7, 5, 5, 3, 3], filter_size_up = [7, 7, 5, 5, 3, 3], 
                           upsample_mode='nearest', downsample_mode='avg',
                           need_sigmoid=True, pad=pad, act_fun='LeakyReLU').type(dtype)

# Compute number of parameters
s  = sum(np.prod(list(p.size())) for p in net.parameters())
print ('Number of params: %d' % s)

迭代

def closure():
    """Run one forward/backward pass and return the total matcher loss.

    Reads and increments the module-level counter ``i``. Prints a progress
    line on every call and, when ``PLOT`` is truthy, plots the current
    output on every 200th call.
    """
    global i

    # Generate an image and crop it to imsize x imsize.
    generated = net(net_input)[:, :, :imsize, :imsize]

    # Forward the generated image through the CNN, then sum the losses
    # exposed by the matcher and backpropagate.
    cnn(vgg_preprocess_var(generated))
    layer_losses = matcher_content.losses.values()
    total_loss = sum(layer_losses)
    total_loss.backward()

    progress = 'Iteration %05d    Loss %.3f' % (i, total_loss.item())
    print(progress, '\r', end='')

    show_preview = PLOT and i % 200 == 0
    if show_preview:
        preview = np.clip(torch_to_np(generated), 0, 1)
        plot_image_grid([preview], 3, 3)

    i += 1
    return total_loss
# Iteration counter that closure() reads and increments via `global i`.
i=0
# Switch the matcher from 'store' to 'match' mode before optimising
# (presumably losses are now computed against the stored features; confirm
# against get_matcher).
matcher_content.mode = 'match'
# Select what to optimise according to OPT_OVER, then run the optimiser for
# num_iter iterations with closure as the step function.
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

由于计算机性能问题,只能跑一段有限的时间,得到下面结果,loss = 0.112

Deep image prior 特征反转部分 feature_inversion 的实现

正文完
 0
icvuln
版权声明:本站原创文章,由 icvuln 于2022-07-31发表,共计7904字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)